# ---------- Constants shared by every calibration run ----------
# Upper bound on zeta used to truncate integrals and root searches.
zeta_upper <- 10
# Candidate search range for capital prices q.
q_solve_range <- c(0.01, 3.2)
# Capital is normalized so that every simulation starts from K = 1.
K0 <- 1
# Evaluation grids: zeta on [0, 1] and the H-capital share w on (0, 1).
zeta_vec <- seq(0, 1, length.out = 200)
w_grid <- seq(0.0001, 0.9999, length.out = 100)


# ******** Generate calibration moments *******
# Solve the model for the parameter list `par` and return calibration moments.
#
# `par` is a list whose elements (AH, theta, beta, lambda_F, d_bar, g_bar,
# avg_zeta, r, delta, lambda, eta_H, eta_L, w0, eta_tilde) are unpacked below.
# Every helper defined inside this function is a closure that reads them.
# `print_q_residual` is accepted but not referenced anywhere in the body.
# Returns a list with `key_moments`, `other_moments`, `equi_sol`, and
# `f_list` (the helper closures, reused by the welfare solvers).
calibration_moments <- function(par, print_q_residual=FALSE){
  
# ---- Unpack parameters into the function environment ----
AH = par$AH
AL = AH / 3.7  # note: this is the ratio target in the data.
theta = par$theta
beta = par$beta
# max() floors lambda_F at 0.1 (so a parameter search cannot push it to zero or below); the floor is not expected to bind.
lambda_F = max(par$lambda_F,0.1)   # just to avoid the search algorithm to go negative. This will not be binding.
# max() floors d_bar at 0 for the same reason.
d_bar = max(par$d_bar,0)  # just to avoid the search algorithm to go negative. This will not be binding.
g_bar = par$g_bar
avg_zeta = par$avg_zeta
r = par$r
delta = par$delta
lambda = par$lambda
eta_H = par$eta_H
eta_L = par$eta_L
w0 = par$w0
eta_tilde = par$eta_tilde

# -------- Model Basics -----------

# Liquidity shock: exponential distribution with rate lambda_F.
# CDF of the liquidity shock.
F_fun <- function(x) pexp(x, rate = lambda_F)
# Density of the liquidity shock.
f_fun <- function(x) dexp(x, rate = lambda_F)
# Inverse of the density: the x solving lambda_F * exp(-lambda_F * x) = y.
f_fun_inverse <- function(y) {
  -log(y / lambda_F) / lambda_F
}
# Quick visual check: MyPlot( zeta_grid, F_fun(zeta_grid) )

# ---- Investment technology ----
# Cost of investing at rate x: x + theta * x^2 / 2.
Phi_fun <- function(x) {
  x + theta * x^2 / 2
}
# Investment rate at capital price q; solves Phi'(i) = 1 + theta * i = q.
iota_bar_fun <- function(q) {
  (q - 1) / theta
}

# Distribution of the zeta shock: exponential with mean avg_zeta.
# The same law is used for both firm types; `type` is currently ignored.
# (A uniform alternative on [0, 2 * avg_zeta] would be
#  punif(zeta, max = 2 * avg_zeta) / dunif(zeta, max = 2 * avg_zeta).)
F_zeta_fun <- function(zeta, type = "H") {
  pexp(zeta, rate = 1 / avg_zeta)
}
f_zeta_fun <- function(zeta, type = "H") {
  dexp(zeta, rate = 1 / avg_zeta)
}

# Solver switches (none of them is read elsewhere in this function).
normal_investment <- TRUE
first_best <- FALSE
plot <- FALSE
with_banks <- FALSE
show_warning <- FALSE

# Initial capital-price values; qH and qL are reassigned once the pricing
# equations are solved below.
qL <- beta + 0.1
qH <- qL * 2

# ------------------ Model Components --------------------------
# Objective of the liquidity choice: F_fun(zeta + l) * q - l, i.e. the
# probability weight F_fun(zeta + l) times the capital price q, net of the
# liquidity l taken on.
obj_baseline_fun <- function(l, zeta, q) {
  weight <- F_fun(zeta + l)
  weight * q - l
}

# ** Solve the break-even level x^* and impose it on the problem. **
# input = c(F_multiplier, zeta). Finds x >= 0 with
#   F_fun(zeta + x) * F_multiplier = x,
# where F_multiplier must be non-negative (the fitted grid starts at 0).
# Returns a one-row data.frame with the root (x_star) and the residual there.
solve_break_even <- function( input ){
  multiplier <- input[1]
  z <- input[2]
  gap <- function(x) F_fun(z + x) * multiplier - x
  # Widen the bracket when the residual is still positive at zeta_upper.
  hi <- if (gap(zeta_upper) > 0) 1000 else zeta_upper
  sol <- uniroot(gap, lower = 0, upper = hi, tol = 10^(-10))
  data.frame(x_star = sol$root, residual = sol$f.root)
}
# Tabulate the break-even level x_star on an (F_multiplier, zeta) grid so it
# can be interpolated cheaply instead of calling uniroot for every query.
# MatrixResults[i, j] = x_star at F_multiplier_grid[i] and zeta_grid[j]
# (rows indexed by F_multiplier, columns by zeta).
N_grid_ponits = 30
F_multiplier_grid = seq(0, 3, length.out=N_grid_ponits)
zeta_grid = seq(0, max(zeta_upper+10, 20), length.out=N_grid_ponits)
MatrixResults <- matrix(0, nrow=length(F_multiplier_grid), ncol=length(zeta_grid))
for(iter_row in seq_along(F_multiplier_grid)){
  for(iter_col in seq_along(zeta_grid)){
    MatrixResults[iter_row, iter_col] = solve_break_even( c( F_multiplier_grid[iter_row], zeta_grid[iter_col] ) )$x_star
  }
}
# Interpolated solution d_hat of the endogenous debt limit equation
#   F_fun(zeta + d_hat) * F_multiplier - d_hat = 0.
# interp2(x, y, Z, ...) expects Z with nrow(Z) == length(y) and
# ncol(Z) == length(x), i.e. rows indexed by y (zeta here). MatrixResults has
# rows indexed by x (F_multiplier), so it is transposed. Because both grids
# have the same length, an untransposed Z passes interp2's dimension check
# but silently swaps the two axes.
# NOTE(review): queries outside the tabulated grid (e.g. F_multiplier > 3)
# are not covered by the table; check interp2's out-of-range behavior.
d_hat_fun <- function(F_mutiplier, zeta){
  return( interp2(F_multiplier_grid, zeta_grid, t(MatrixResults), xp=F_mutiplier, yp=zeta, method = "linear") )
}

# **** An "explicit" solution to individual optimization
# Solve one firm type's liquidity / default problem at capital price q for
# all zeta at once. The solution is returned as closures in zeta.
#
# Args:
#   q: capital price of this firm type.
#   d_bar: private debt cap; g_bar: government financing cap.
#   Moral_Hazard: if FALSE, the strategic-default payoff beta is set to 0
#     (locally) and a defaulting firm borrows nothing.
#   endogenous_debt_limit: if TRUE, private debt of a defaulting firm is also
#     capped by the interpolated limit d_hat_fun(max(q - beta, 0), zeta + g_bar).
# Returns: list(zeta_bar, zeta_cutoff, l_fun, x1_fun, x2_fun, C_indicator,
#   obj1_fun, obj2_fun, obj_optimal_fun, interest_rate_fun,
#   endogenous_private_debt_limit_fun, zeta_cutoff_residual_fun).
solve_individual_optimal_faster <- function(q, d_bar, g_bar, Moral_Hazard=TRUE, endogenous_debt_limit=TRUE){
  if(!Moral_Hazard){
    # If we remove moral hazard, then equivalently we have beta=0
    # (this binds a local beta; the enclosing beta is unchanged).
    beta = 0
  }
  # Then solve for zeta_bar and zeta_cutoff. 
  # zeta_bar solves f_fun(zeta_bar) * q = 1; it is 0 when f_fun(0) * q < 1.
  if( f_fun(0)*q - 1 < 0 ){
    zeta_bar = 0
  }else{
    sol_baseline = uniroot( function(x) f_fun(x)*q - 1,  c(0,10) )  
    zeta_bar = sol_baseline$root
  }
  # Liquidity of a firm that repays: tops zeta up to zeta_bar, clamped to
  # [0, d_bar + g_bar].
  x1_fun <- function(zeta){
    return( pmax(pmin(zeta_bar-zeta, d_bar + g_bar), 0) )
  }
  endogenous_private_debt_limit_fun <- function(zeta){  # Given g_bar
    if(endogenous_debt_limit){
      return(  d_hat_fun(pmax(rep(q-beta,length(zeta)),0), zeta+g_bar)  )
    }else{
      return( rep(d_bar, length(zeta))  )  # Note: when d_hat = d_bar, effectively there is no endogenous debt limit.
    }
  }
  # Liquidity of a firm that strategically defaults: the maximum it can raise,
  # min(private limit, d_bar) + g_bar.
  x2_fun <- function(zeta){
    if(Moral_Hazard){
      return(  pmin(endogenous_private_debt_limit_fun(zeta), d_bar) +  g_bar )
    }else{
      # Without moral hazard, the strategy is to declare bankruptcy and do not borrow at all
      return(0*zeta)
    }
  }
  obj1_fun <- function(zeta){  # This is the firm profit when it honors payment. 
    return(  obj_baseline_fun( x1_fun(zeta), zeta, q ) )
  } 
  obj2_fun <- function(zeta){  # This is the firm profit when it strategically defaults.
    return( F_fun(x2_fun(zeta)+zeta)*beta )
  }
  # Next obtain the cutoff.
  # Residual q - x1 / F_fun(x1 + zeta) - beta: positive where repaying beats
  # defaulting (see the branch comments below).
  zeta_cutoff_residual_fun <- function(zeta){
    return( q - x1_fun(zeta)/F_fun(x1_fun(zeta)+zeta) - beta )
  }
  # The residual is checked at 1e-6 rather than 0, presumably to avoid 0/0 when x1 = F = 0.
  if( all( zeta_cutoff_residual_fun(c(0.000001,zeta_upper))>0 ) ){  # In this case, no default is always better
    zeta_cutoff = 0
  }else if( all( zeta_cutoff_residual_fun(c(0.000001,zeta_upper))<0  ) ){ # In this case, default is always better
    zeta_cutoff = zeta_upper
  }else{
    zeta_cutoff_result = uniroot( zeta_cutoff_residual_fun, interval = c(0.000001, zeta_upper) )
    zeta_cutoff = zeta_cutoff_result$root
    if( abs(zeta_cutoff_residual_fun(zeta_cutoff))>10^(-6) ){
      print( "solution is inaccurate for zeta cutoff!" )
    }
  }
  C_indicator <- function(zeta){ 
    # Indicator of whether the firm continues (repays): zeta >= zeta_cutoff.
    return( zeta >= zeta_cutoff )
  }
  l_fun <- function(zeta){
    # Optimal liquidity: x1 for continuing firms, x2 for defaulting firms.
    return( x1_fun(zeta)*C_indicator(zeta) + x2_fun(zeta)*(1-C_indicator(zeta))  )
  }
  obj_optimal_fun <- function(zeta){
    # Optimal firm value: the payoff of whichever action C_indicator selects.
    return( obj1_fun(zeta)*C_indicator(zeta) + obj2_fun(zeta)*(1-C_indicator(zeta)) )
  }
  interest_rate_fun <- function(zeta){
    # Implied rate 1 / F_fun(zeta + l) - 1.
    return( 1/F_fun(zeta+l_fun(zeta)) - 1 )
  }
  return( 
    list(zeta_bar=zeta_bar, zeta_cutoff=zeta_cutoff, l_fun=l_fun, 
            x1_fun=x1_fun, x2_fun=x2_fun, C_indicator=C_indicator, 
            obj1_fun = obj1_fun, obj2_fun = obj2_fun, obj_optimal_fun=obj_optimal_fun, 
            interest_rate_fun=interest_rate_fun, endogenous_private_debt_limit_fun=endogenous_private_debt_limit_fun, 
            zeta_cutoff_residual_fun=zeta_cutoff_residual_fun
         )
  )
}

# *** Next, calculate the post-crisis capital stock. ***
# Aggregate one firm type's crisis outcome at capital price q.
#
# Args:
#   q: capital price of this type; d_bar / g_bar: private and government
#     financing caps; type: "H" or "L", passed to f_zeta_fun.
# Returns a list with the closures l_fun, l_fun_no_moral_hazard, C_indicator,
#   obj_optimal_fun and the scalars kappa (surviving fraction of capital),
#   total_borrowing, over_borrowing, frac_over_borrowing, creditor profit
#   statistics, zeta_cutoff, endogenous_debt_limit_avg, and government /
#   private sector financing.
post_crisis_fraction_of_capital <- function(q, d_bar, g_bar, type="H"){
  # Note: q is the capital price; F_fun is the CDF of the liquidity shock and
  # f_zeta_fun the density of zeta (every integral below runs over zeta).
  result = solve_individual_optimal_faster(q, d_bar, g_bar)
  l_fun = result$l_fun
  result_no_moral_hazard = solve_individual_optimal_faster(q, d_bar, g_bar, Moral_Hazard = FALSE)
  l_fun_no_moral_hazard = result_no_moral_hazard$l_fun
  C_indicator = result$C_indicator
  obj_optimal_fun = result$obj_optimal_fun
  
  # Over-borrowing = total borrowing minus borrowing without moral hazard.
  total_borrowing = integrate( function(zeta) l_fun(zeta)*f_zeta_fun(zeta,type), lower=0, upper=zeta_upper )$value
  no_moral_hazard_borrowing = integrate( function(zeta) l_fun_no_moral_hazard(zeta)*f_zeta_fun(zeta,type), lower=0, upper=zeta_upper )$value
  over_borrowing = total_borrowing - no_moral_hazard_borrowing
  
  # NOTE(review): NaN when total_borrowing is 0.
  frac_over_borrowing = over_borrowing/total_borrowing
  # Fraction of this type's capital that survives the crisis.
  kappa = integrate( function(zeta) F_fun(zeta+l_fun(zeta))*f_zeta_fun(zeta,type), lower=0, upper=zeta_upper )$value 
  # Calculate creditor losses. 
  creditor_profit_total = integrate( function(zeta) (F_fun(zeta+l_fun(zeta))*(q-beta) - l_fun(zeta) )*f_zeta_fun(zeta,type) , lower=0, upper=result$zeta_cutoff )$value
  # note: creditors only make profit or loss in the case of default. Otherwise break-even, even in the presence of government financing.
  creditor_contribution = integrate( function(zeta) l_fun(zeta)*f_zeta_fun(zeta,type), lower=0, upper=zeta_upper )$value
  creditor_profit_frac = creditor_profit_total/creditor_contribution
  # Government financing: for continuing firms (zeta >= zeta_cutoff), borrowing
  # above the private cap d_bar; each defaulting firm (zeta < zeta_cutoff)
  # counts the full g_bar. Private-sector financing is the remainder.
  govt_sector_financing = integrate( function(zeta)  (l_fun(zeta) - pmin( d_bar, l_fun(zeta) )) *f_zeta_fun(zeta,type) , lower=result$zeta_cutoff, upper=zeta_upper, rel.tol = 1e-10)$value +
    integrate( function(zeta)  g_bar * f_zeta_fun(zeta,type) , lower=0, upper=result$zeta_cutoff )$value
  private_sector_financing = total_borrowing - govt_sector_financing
  # Average endogenous private debt limit over defaulting firms (zeta < zeta_cutoff).
  # NOTE(review): when zeta_cutoff is 0 the ratio below is 0/0 (NaN), but the
  # guard only covers zeta_cutoff == zeta_upper; confirm which case was intended.
  if(result$zeta_cutoff==zeta_upper){
    endogenous_debt_limit_avg = NA
  }else{
    endogenous_debt_limit_avg = integrate( function(zeta) result$endogenous_private_debt_limit_fun(zeta)*f_zeta_fun(zeta,type), lower=0, upper=result$zeta_cutoff )$value  / integrate( function(zeta) f_zeta_fun(zeta,type), lower=0, upper=result$zeta_cutoff )$value
  }
  return( list(l_fun=l_fun,l_fun_no_moral_hazard=l_fun_no_moral_hazard,C_indicator=C_indicator,kappa=kappa, over_borrowing=over_borrowing, total_borrowing=total_borrowing, frac_over_borrowing=frac_over_borrowing, 
               creditor_profit_total=creditor_profit_total, creditor_contribution=creditor_contribution, creditor_profit_frac=creditor_profit_frac, zeta_cutoff=result$zeta_cutoff, obj_optimal_fun=obj_optimal_fun,
               endogenous_debt_limit_avg=endogenous_debt_limit_avg, govt_sector_financing=govt_sector_financing, private_sector_financing=private_sector_financing)  )
}


# *** Capital prices qH and qL ***
# With type transitions the two pricing equations are coupled through the
# eta_tilde * (q_other - q_own) term, so qH and qL are solved jointly.
#
# Returns c(residual_H, residual_L) at q_vec = c(qH, qL). Each residual is
#   A - Phi(i) + i * q - (delta + r) * q + eta_tilde * (q_other - q)
#     + lambda * expected crisis gain,
# where i = iota_bar_fun(q).
q_residual <- function(q_vec, d_bar, g_bar){
  q_H <- q_vec[1]
  q_L <- q_vec[2]
  sol_H <- solve_individual_optimal_faster(q_H, d_bar, g_bar)
  sol_L <- solve_individual_optimal_faster(q_L, d_bar, g_bar)
  # Expected crisis gain: optimal firm value minus q, integrated over zeta.
  # Both regions of the firm's solution enter through obj_optimal_fun.
  crisis_gain <- function(sol, q, type) {
    integrate(function(zeta) f_zeta_fun(zeta, type) * (sol$obj_optimal_fun(zeta) - q),
              lower = 0, upper = zeta_upper)$value
  }
  gain_H <- crisis_gain(sol_H, q_H, "H")
  gain_L <- crisis_gain(sol_L, q_L, "L")
  pricing_residual <- function(A, q_own, q_other, gain) {
    inv <- iota_bar_fun(q_own)
    A - Phi_fun(inv) + inv * q_own - (delta + r) * q_own +
      eta_tilde * (q_other - q_own) + lambda * gain
  }
  c(pricing_residual(AH, q_H, q_L, gain_H),
    pricing_residual(AL, q_L, q_H, gain_L))
}
# Solve the two pricing equations jointly at the baseline policy
# (d_bar, g_bar); the solution seeds later equilibrium solves via q_init_vec.
result <- fsolve(function(q_vec) q_residual(q_vec, d_bar, g_bar), c(2, 0.6))
qH <- result$x[1]
qL <- result$x[2]

worst_residual <- max(abs(q_residual(result$x, d_bar, g_bar)))
if (worst_residual > 10^(-6)) {
  print("qH and qL are not accurately solved!")
}
q_init_vec <- c(qH, qL)

# *** Solve Equilibrium ***
# Solve the joint pricing equations for policy (d_bar, g_bar), starting from
# q_init_vec. Then collect both types' firm-level solutions and aggregate
# crisis statistics into one list.
solve_equilibrium <- function(d_bar, g_bar){
  q_sol <- fsolve(function(q_vec) q_residual(q_vec, d_bar, g_bar), q_init_vec)$x
  qH <- q_sol[1]
  qL <- q_sol[2]
  firm_H <- solve_individual_optimal_faster(qH, d_bar, g_bar)
  firm_L <- solve_individual_optimal_faster(qL, d_bar, g_bar)
  agg_H <- post_crisis_fraction_of_capital(qH, d_bar, g_bar, type = "H")
  agg_L <- post_crisis_fraction_of_capital(qL, d_bar, g_bar, type = "L")
  # Attach the surviving-capital fractions to the firm solutions.
  firm_H$kappa <- agg_H$kappa
  firm_L$kappa <- agg_L$kappa
  list(
    d_bar = d_bar, g_bar = g_bar,
    qH = qH, qL = qL,
    iH = iota_bar_fun(qH), iL = iota_bar_fun(qL),
    kappaH = agg_H$kappa, kappaL = agg_L$kappa,
    total_borrowing_H = agg_H$total_borrowing, total_borrowing_L = agg_L$total_borrowing,
    frac_over_borrowing_H = agg_H$frac_over_borrowing, frac_over_borrowing_L = agg_L$frac_over_borrowing,
    creditor_profit_frac_H = agg_H$creditor_profit_frac, creditor_profit_frac_L = agg_L$creditor_profit_frac,
    zeta__H = agg_H$zeta_cutoff, zeta__L = agg_L$zeta_cutoff,
    result_qH = firm_H, result_qL = firm_L,
    result_qH_extra = agg_H, result_qL_extra = agg_L,
    zeta_bar_H = firm_H$zeta_bar, zeta_bar_L = firm_L$zeta_bar
  )
}
# Baseline equilibrium and the counterfactual without government financing.
equi_sol <- solve_equilibrium(d_bar, g_bar)
equi_sol_no_govt <- solve_equilibrium(d_bar, g_bar = 0)

# *** Drift of capital quality (the H-capital share w) ***
# mu_w = w (1 - w) (iH - iL + eta_H / w - eta_L / (1 - w)).
muw_fun <- function(iH, iL, w){
  spread <- iH - iL + eta_H / w - eta_L / (1 - w)
  w * (1 - w) * spread
}
# Jump in w at a crisis. kappaH / kappaL are the surviving fractions of the
# pre-crisis H / L capital, so the post-crisis share is
# w * kappaH / (w * kappaH + (1 - w) * kappaL); this returns it minus w.
Deltaw_fun <- function(kappaH, kappaL, w){
  relative_L <- (1 - w) * kappaL / kappaH
  w * (1 / (w + relative_L) - 1)
}
# Drift of aggregate capital K: -delta + w * iH + (1 - w) * iL + eta_H + eta_L.
muK_fun <- function(iH, iL, w){
  drift <- -delta + w * iH
  drift <- drift + (1 - w) * iL
  drift + eta_H + eta_L
}
# Proportional jump of aggregate capital at a crisis:
# w * kappaH + (1 - w) * kappaL - 1 (negative when capital is lost).
DeltaK_fun <- function(kappaH, kappaL, w){
  surviving <- w * kappaH + (1 - w) * kappaL
  surviving - 1
}
# Relative change in output caused by a crisis at state w, normalizing
# pre-crisis capital to 1. Output per unit of capital moves with the share w,
# and capital scales by (1 + DeltaK).
Delta_output_fun <- function(kappaH, kappaL, w){
  y_pre <- w * AH + (1 - w) * AL
  w_after <- w + Deltaw_fun(kappaH, kappaL, w)
  capital_after <- 1 + DeltaK_fun(kappaH, kappaL, w)
  y_post <- (w_after * AH + (1 - w_after) * AL) * capital_after
  (y_post - y_pre) / y_pre
}
# Aggregate investment cost per unit of capital: w-weighted Phi_fun.
investment_cost_fun <- function(iH, iL, w){
  cost_H <- Phi_fun(iH)
  cost_L <- Phi_fun(iL)
  w * cost_H + (1 - w) * cost_L
}
# Expected borrowing in a crisis: integral of l_fun(zeta) against the zeta
# density (f_zeta_fun's default type) over [0, zeta_upper].
l_dN_fun <- function(l_fun){
  integrand <- function(zeta) l_fun(zeta) * f_zeta_fun(zeta)
  integrate(integrand, lower = 0, upper = zeta_upper)$value
}

# Time step of Simulate_model: one step is one month.
dt <- 1 / 12
# Burn-in steps Simulate_model runs before recording a path.
burn_N <- 1000
# Simulate the aggregate state (w, K) with step dt for sim_N steps, after a
# burn-in of burn_N steps, at the equilibrium for policy (d_bar, g_bar).
# Crisis arrivals are Bernoulli(lambda * dt) draws per step. seed > 0 seeds
# the RNG; seed <= 0 switches crises off entirely.
# Returns a data.table with columns t, w and K.
Simulate_model <- function(sim_N, d_bar, g_bar, seed=1){
  if(seed>0){
    set.seed(seed)
    dN_burn = runif(burn_N)<lambda*dt
    dN_vec = runif(sim_N)<lambda*dt
  }else{
    dN_burn = rep(0, burn_N)
    dN_vec = rep(0, sim_N)
  }
  equi_sol = solve_equilibrium(d_bar, g_bar)
  iH = equi_sol$iH
  iL = equi_sol$iL
  kappaH = equi_sol$kappaH
  kappaL = equi_sol$kappaL
  TB = data.table(t=seq_len(sim_N)*dt)
  # Burn-in: evolve w only, starting from an even split of capital.
  w = 0.5
  for(iter in seq_len(burn_N)){
    w = w + muw_fun(iH, iL, w)*dt + Deltaw_fun(kappaH, kappaL, w)*dN_burn[iter]
  }
  w_vec = rep(w, sim_N)
  # Capital starts from the normalization K0 (= 1), as in simulation_fun,
  # not from the share w.
  K_vec = rep(K0, sim_N)
  # seq_len() keeps the loop empty when sim_N <= 1 (1:(sim_N-1) would not).
  for(iter in seq_len(max(sim_N-1, 0))){
    w = w_vec[iter]
    K = K_vec[iter]
    w_vec[iter+1] = w + muw_fun(iH, iL, w)*dt + Deltaw_fun(kappaH, kappaL, w)*dN_vec[iter]
    K_vec[iter+1] = K + K*( muK_fun(iH, iL, w)*dt + DeltaK_fun(kappaH, kappaL, w)*dN_vec[iter] )
  }
  TB[,w:=w_vec]
  TB[,K:=K_vec]
  return(TB)
}

# Simulate (w, K) for years_sim years with step dt and build a data.table of
# derived aggregates.
#
# Args:
#   equi_sol: output of solve_equilibrium (uses iH, iL, kappaH, kappaL, qH, qL).
#   years_sim: horizon in years; the path has N_sim = round(years_sim / dt) steps.
#   seed: RNG seed (always set).
#   dt: step length in years (default 1, i.e. annual steps).
#   dN_vec: optional crisis indicators, one per step. If all NA (default), they
#     are drawn as Bernoulli(lambda * dt). Entries beyond N_sim are ignored;
#     fewer than N_sim entries is an error.
#   w_init: initial H-capital share.
# Returns: data.table with T, w, K, dN and the derived columns below.
simulation_fun <- function(equi_sol, years_sim=50, seed=1,  dt=1, dN_vec=NA, w_init=w0){
  set.seed(seed)
  iH=equi_sol$iH
  iL=equi_sol$iL
  kappaH = equi_sol$kappaH
  kappaL = equi_sol$kappaL
  qH = equi_sol$qH
  qL = equi_sol$qL
  N_sim = round(years_sim/dt)
  if(sum(!is.na(dN_vec))==0){
    dN_vec = as.numeric(runif(N_sim)<lambda*dt) 
  }else{
    if(length(dN_vec) < N_sim){
      stop("dN_vec must have at least round(years_sim/dt) elements.", call. = FALSE)
    }
    # Keep exactly one shock per step; a longer vector would make data.frame()
    # below silently recycle w and K to the length of dN_vec.
    dN_vec = dN_vec[seq_len(N_sim)]
  }
  w_vec = rep(w_init, N_sim)
  K_vec = rep(K0, N_sim)
  
  # seq_len(N_sim)[-1] is 2..N_sim and is empty when N_sim <= 1.
  for(iter in seq_len(N_sim)[-1]){
    w = w_vec[iter-1]
    K = K_vec[iter-1]
    # The shock dated iter-1 moves the state into period iter; this timing
    # matters for the discounted-consumption calculation below.
    w_vec[iter] = w + muw_fun(iH,iL,w)*dt + Deltaw_fun(kappaH,kappaL,w)*dN_vec[iter-1]
    K_vec[iter] = K * ( 1 + muK_fun(iH,iL,w)*dt +  DeltaK_fun(kappaH,kappaL,w)*dN_vec[iter-1] )
  }
  TB_sim = as.data.table(data.frame( T=seq_len(N_sim)*dt,  w=w_vec, K=K_vec, dN=dN_vec))
  
  TB_sim[  , output_per_K:= AH*w+AL*(1-w) ]  # Per unit of capital productivity
  TB_sim[  , GDP:= output_per_K*K ]  # Total output: productivity times capital
  TB_sim[  , q := qH*w + qL*(1-w) ]
  
  TB_sim[  , normal_investment_per_K:= w*iH+(1-w)*iL ]  # Note: translate lumpy investment into flow-based ones. 
  TB_sim[  , crisis_investment_per_K:= 0 ]  # These are multiplying the shocks for the next period. 
  TB_sim[  , normal_investment_to_output:= normal_investment_per_K/output_per_K ]
  TB_sim[  , normal_consumption_per_K := output_per_K-normal_investment_per_K ]
  TB_sim[  , normal_consumption_to_output := normal_consumption_per_K/output_per_K ]
  TB_sim[  , cumulative_discounted_consumption := cumsum((normal_consumption_per_K*K - crisis_investment_per_K*K)*exp(-r * T))  ]
  TB_sim[  , KH:= w*K ]
  TB_sim[  , KL:= (1-w)*K ]
  return(TB_sim)
}

# Output drop (relative to pre-crisis output) at state w_evaluation when the
# government financing cap is g_bar. Capital prices are held at the supplied
# qH and qL rather than re-solved; tol is integrate()'s relative tolerance.
output_drop_and_gbar <- function(g_bar, w_evaluation, qH = 2*beta+0.1, qL= beta+0.1, tol=10^(-6)){
  firm_H <- solve_individual_optimal_faster(qH, d_bar, g_bar)
  firm_L <- solve_individual_optimal_faster(qL, d_bar, g_bar)
  # Fraction of a type's capital surviving the crisis.
  surviving <- function(firm, type) {
    integrate(function(zeta) F_fun(zeta + firm$l_fun(zeta)) * f_zeta_fun(zeta, type),
              lower = 0, upper = zeta_upper, rel.tol = tol)$value
  }
  kappa_H <- surviving(firm_H, "H")
  kappa_L <- surviving(firm_L, "L")
  Delta_output_fun(kappa_H, kappa_L, w_evaluation)
}

# ----------------------- Model Moment Functions -----------------------
# Augment an equilibrium with the moments used for calibration.
#
# Convention: w is the share of H-type capital (output per unit of capital is
# AH * w + AL * (1 - w)). H-type quantities are therefore weighted by w and
# L-type quantities by (1 - w) throughout.
#
# Args:
#   equi_sol: list returned by solve_equilibrium().
#   sim_years: horizon of the long simulation used for the mean/median of w
#     (simulation_fun's default step is dt = 1, i.e. one step per year).
# Returns: equi_sol with extra fields, including the data.frames
#   `key_moments` (used to estimate parameters) and `other_moments`.
solve_other_moments <- function(equi_sol, sim_years=10000){
  result_qH = equi_sol$result_qH
  result_qL = equi_sol$result_qL
  # First, calculate the average creditor recovery rate.
  # numerator: recovery per unit lent, F * (q - beta) / l, over zeta < zeta_cutoff;
  # mass: liquidity failures (1 - F) plus strategic defaults F * (1 - C) up to zeta_bar.
  numeratorL = integrate( function(zeta) F_fun(zeta+result_qL$l_fun(zeta))*(equi_sol$qL-beta)/result_qL$l_fun(zeta)*f_zeta_fun(zeta,type="L"), lower=0, upper=result_qL$zeta_cutoff )
  massL = integrate( function(zeta) ((1-F_fun(zeta+result_qL$l_fun(zeta))) + F_fun(zeta+result_qL$l_fun(zeta))*(1-result_qL$C_indicator(zeta)))*f_zeta_fun(zeta,type="L"), lower=0, upper=result_qL$zeta_bar )
  numeratorH = integrate( function(zeta) F_fun(zeta+result_qH$l_fun(zeta))*(equi_sol$qH-beta)/result_qH$l_fun(zeta)*f_zeta_fun(zeta,type="H"), lower=0, upper=result_qH$zeta_cutoff )
  massH = integrate( function(zeta) ((1-F_fun(zeta+result_qH$l_fun(zeta))) + F_fun(zeta+result_qH$l_fun(zeta))*(1-result_qH$C_indicator(zeta)))*f_zeta_fun(zeta,type="H"), lower=0, upper=result_qH$zeta_bar )
  creditor_recovery_rate_fun <- function(w){
    # H-type terms carry the H share w, L-type terms (1 - w).
    return(   (w*numeratorH$value + (1-w)*numeratorL$value) / (w*massH$value + (1-w)*massL$value)   )
  }
  TB = simulation_fun(equi_sol, years_sim = sim_years)  
  w_avg = mean(TB$w)
  w_median = median(TB$w)
  # Crisis-free path; w at its end approximates the deterministic steady state.
  # dN_vec needs one entry per step, i.e. steady_state_years entries at dt = 1.
  steady_state_years = 100
  TB_stead_state = simulation_fun(equi_sol, years_sim = steady_state_years, dN_vec=rep(0,steady_state_years))  
  w_steady = TB_stead_state$w[nrow(TB_stead_state)]
  avg_creditor_recovery = creditor_recovery_rate_fun(w_avg)
  # Next, calculate the average output drop percentage. 
  avg_output_drop = Delta_output_fun(result_qH$kappa, result_qL$kappa, w_avg)
  equi_sol$w_avg = w_avg
  equi_sol$w_median = w_median
  equi_sol$w_steady = w_steady
  equi_sol$avg_creditor_recovery = avg_creditor_recovery
  equi_sol$avg_output_drop = avg_output_drop
  # Calculate the average market to book ratio, i.e., average q 
  equi_sol$avg_q = equi_sol$qH*w_avg + equi_sol$qL*(1-w_avg)
  # Then the average investment to capital ratio 
  equi_sol$avg_investment_to_K = equi_sol$iH*w_avg + equi_sol$iL*(1-w_avg)
  # Average investment to GDP
  # NOTE(review): this averages type-level i/A ratios rather than dividing
  # aggregate investment by aggregate output; confirm this is intended.
  equi_sol$avg_investment_to_GDP = equi_sol$iH/AH*w_avg + equi_sol$iL/AL*(1-w_avg)
  # Average productivity 
  equi_sol$avg_productivity = AH*w_avg + AL*(1-w_avg)
  # Total lending 
  equi_sol$total_financing = w_avg*equi_sol$result_qH_extra$total_borrowing + (1-w_avg)*equi_sol$result_qL_extra$total_borrowing
  # Total takeup of government financing
  equi_sol$total_takeup = w_avg*equi_sol$result_qH_extra$govt_sector_financing + (1-w_avg)*equi_sol$result_qL_extra$govt_sector_financing
  total_take_up_fun <- function(w){
    return( w*equi_sol$result_qH_extra$govt_sector_financing + (1-w)*equi_sol$result_qL_extra$govt_sector_financing )
  }
  # Takeup ratio
  equi_sol$takeup_ratio = equi_sol$total_takeup / equi_sol$g_bar
  # Private financing
  equi_sol$private_financing = equi_sol$total_financing - equi_sol$total_takeup
  # average capital growth rate 
  equi_sol$avg_muw = muw_fun( equi_sol$iH, equi_sol$iL, w_avg )
  equi_sol$avg_muK = muK_fun( equi_sol$iH, equi_sol$iL, w_avg )
  # bankruptcy rate 
  equi_sol$liquidity_bankruptcy_L =  integrate( function(zeta) (1-F_fun(zeta+result_qL$l_fun(zeta)))*f_zeta_fun(zeta,type="L"), lower=0, upper=result_qL$zeta_bar )$value
  equi_sol$liquidity_bankruptcy_H =  integrate( function(zeta) (1-F_fun(zeta+result_qH$l_fun(zeta)))*f_zeta_fun(zeta,type="H"), lower=0, upper=result_qH$zeta_bar )$value
  equi_sol$strategic_bankruptcy_L =  integrate( function(zeta) F_fun(zeta+result_qL$l_fun(zeta))*(1-result_qL$C_indicator(zeta))*f_zeta_fun(zeta,type="L"), lower=0, upper=result_qL$zeta_bar )$value
  equi_sol$strategic_bankruptcy_H =  integrate( function(zeta) F_fun(zeta+result_qH$l_fun(zeta))*(1-result_qH$C_indicator(zeta))*f_zeta_fun(zeta,type="H"), lower=0, upper=result_qH$zeta_bar )$value
  equi_sol$liquidity_bankruptcy_total = w_avg*equi_sol$liquidity_bankruptcy_H + (1-w_avg)*equi_sol$liquidity_bankruptcy_L
  equi_sol$strategic_bankruptcy_total = w_avg*equi_sol$strategic_bankruptcy_H + (1-w_avg)*equi_sol$strategic_bankruptcy_L
  equi_sol$bankruptcy_total = equi_sol$liquidity_bankruptcy_total + equi_sol$strategic_bankruptcy_total
  equi_sol$total_take_up_fun = total_take_up_fun
  equi_sol$takeup_to_GDP_ratio = equi_sol$total_takeup / equi_sol$avg_productivity
  # finally, the volatility of TFP process.
  # Two-state type-switching generator with rate eta_tilde; p is the
  # probability of staying in the same state after one year.
  Q = matrix( c(-eta_tilde,eta_tilde, eta_tilde, -eta_tilde), 2, 2 )
  p = expm(Q)[1,1]
  TFP_vol_sd_over_mean = (w_avg/AH + (1-w_avg)/AL) * (AH-AL)*sqrt( p*(1-p) )
  
  # Summarize key moments that are used to estimate parameters
  key_moments = data.frame(investment_to_K=equi_sol$avg_investment_to_K, avg_GDP_drop=equi_sol$avg_output_drop, creditor_recover=equi_sol$avg_creditor_recovery, avg_productivity=equi_sol$avg_productivity, TFP_vol_sd_over_mean=TFP_vol_sd_over_mean  )
  # Other useful moments (not used for estimating parameters)
  other_moments = data.frame(muK=equi_sol$avg_muK,takeup_raio=equi_sol$takeup_ratio,liquidity_bankruptcy=equi_sol$liquidity_bankruptcy_total,strategic_bankruptcy=equi_sol$strategic_bankruptcy_total, avg_q=equi_sol$avg_q, w_avg=equi_sol$w_avg)
  equi_sol$key_moments = key_moments
  equi_sol$other_moments = other_moments
  return(equi_sol)
}

# Solve for the Welfare function
# Solve for W(w) on w_grid for policy (d_bar, g_bar) by false-transient
# iteration. W is repeatedly moved along the residual
#   Dt_W = A - I + W muK + W' muw + lambda (W(w + Delta_w) (1 + Delta_K) - W - l_dN) - r W
# until sum(|Dt_W|) falls below 1e-5 or iter_max iterations are reached.
# The W * muK and (1 + Delta_K) terms suggest W is value per unit of capital;
# confirm.
#
# Args:
#   w_grid: grid of the H-capital share w.
#   d_bar, g_bar: private and government financing caps.
#   shutoff_slippery_slope: if TRUE, the crisis jump in w is taken from the
#     g_bar = 0 equilibrium (government financing no longer shifts firm
#     composition). W is then smoothed with conspline() instead of a linear fit.
#   W_vec_init: optional starting values on w_grid (any NA = use the guess below).
# Returns: list with spline closures W_fun / W_prime_fun, the grid, prices,
#   investment rates, the per-w coefficients of the equation, and equi_sol.
solve_welfare <- function( w_grid = seq(0.0001,0.9999,length.out = 100), d_bar, g_bar, shutoff_slippery_slope=FALSE, W_vec_init=NA ){
  # Both equilibria are solved on every call; the g_bar = 0 one is only used
  # when shutoff_slippery_slope is TRUE.
  equi_sol = solve_equilibrium(d_bar, g_bar)
  equi_sol_no_govt = solve_equilibrium(d_bar, g_bar=0)
  qH = equi_sol$qH
  qL = equi_sol$qL
  iH = equi_sol$iH
  iL = equi_sol$iL
  # Note: regardless of the liquidity survival, firms borrow, and we have to account for all firm borrowing. 
  l_dN = w_grid*l_dN_fun(equi_sol$result_qH$l_fun) + (1-w_grid)*l_dN_fun(equi_sol$result_qL$l_fun)
  A = w_grid*AH + (1-w_grid)*AL
  I = investment_cost_fun(iH, iL, w_grid)
  muw = muw_fun(iH,iL,w_grid)
  muK = muK_fun(iH,iL,w_grid)
  Delta_K = DeltaK_fun(equi_sol$result_qH_extra$kappa, equi_sol$result_qL_extra$kappa, w_grid)
  Delta_w = Deltaw_fun(equi_sol$result_qH_extra$kappa, equi_sol$result_qL_extra$kappa, w_grid)
  if(shutoff_slippery_slope){
    # Ignore the impact of government intervention on the firm quality
    Delta_w = Deltaw_fun(equi_sol_no_govt$result_qH_extra$kappa, equi_sol_no_govt$result_qL_extra$kappa, w_grid)
  }
  # Initialization: either the supplied W_vec_init or the closed-form value
  # obtained by setting W' = 0 and W(w + Delta_w) = W(w) in the equation above.
  if(sum(is.na(W_vec_init))==0){
    W_vec = W_vec_init
  }else{
    W_vec = (A - I - lambda*l_dN) / (r - muK - lambda*Delta_K)
  }
  N = length(w_grid)
  # `diff` holds the residual norm. It shadows base::diff only as a value; the
  # call diff(diff_vec) below still resolves to the base function.
  diff = 1 
  diff_vec = c()
  iter_max = 1000
  iter = 0
  step_size = 0.1;  step_size_vec=c()
  while(diff>0.00001 & iter<iter_max){
    iter = iter+1
    if(shutoff_slippery_slope){
      # NOTE(review): conspline is defined outside this section; presumably a
      # shape-constrained spline smoother (type 1) — confirm.
      W_vec = conspline(W_vec, w_grid, 1)$muhat
    }else{
      # Project W onto a linear function of w.
      W_vec = predict( lm( W_vec ~ w_grid ) )
    }
    # derivative_fun is defined elsewhere (presumably a numerical derivative on the grid).
    W_prime_vec = c(derivative_fun(w_grid, W_vec))
    # W just after a crisis, at w + Delta_w; rule=2 clamps to the end values of the grid.
    W_post_dN = approx(w_grid, W_vec, w_grid + Delta_w, rule=2 )$y
    Dt_W = A-I + W_vec*muK + W_prime_vec*muw + lambda*(W_post_dN*(1+Delta_K)-W_vec-l_dN) - r*W_vec
    
    # After 10 iterations: grow the step by 2% when the residual norm fell at
    # every one of the recent iterations (window of up to 50); otherwise shrink
    # it by 2%, floored at 0.001.
    if( iter>=10 && sum( diff(diff_vec)[(length(diff_vec)-min(50,length(diff_vec))+1):(length(diff_vec)-1)]>=0) ==0 ){ # If the difference declines well in the last 50 rounds
      step_size = step_size*1.02
    }else{
      step_size = max(step_size/1.02, 0.001)
    }
    step_size_vec = c(step_size_vec, step_size)
    
    W_vec_new = W_vec + step_size*Dt_W   # trick: dynamically adjust the step size
    diff = sum( abs(Dt_W) )
    W_vec = W_vec_new
    diff_vec = c(diff_vec, diff)
  }
  # Accuracy is only reported for the full model, not the shut-off variant.
  if( !shutoff_slippery_slope & diff>0.0001 ){
    print( "Welfare function is not accurately solved!" )
  }
  W_fun <- splinefun(w_grid, W_vec)
  W_prime_fun <- splinefun(w_grid, derivative_fun(w_grid, W_vec))
  return( list(W_fun=W_fun, W_prime_fun=W_prime_fun, w_grid=w_grid, qH=qH, qL=qL, iH=iH, iL=iL, l_dN=l_dN, I=I,A=A, muw=muw, muK=muK, Delta_w=Delta_w, Delta_K=Delta_K, equi_sol=equi_sol) )
}

# Moments for the baseline and no-government equilibria. solve_other_moments
# calls simulation_fun with its default step dt = 1, so these are 10000
# annual steps (not monthly).
equi_sol = solve_other_moments(equi_sol, sim_years = 10000) # monthly simulation of 10000 years
equi_sol_no_govt =  solve_other_moments(equi_sol_no_govt, sim_years = 10000)
# Record the derived AL alongside the input parameters.
par$AL = AL
equi_sol$par = par

# other basic parameters
# (dt here is the monthly step defined for Simulate_model.)
equi_sol$dt = dt
equi_sol$w_grid = w_grid
equi_sol$zeta_grid = zeta_grid

# Export the model closures (they capture this call's parameters) so the
# welfare solvers can reuse them.
f_list = list( solve_individual_optimal_faster=solve_individual_optimal_faster, post_crisis_fraction_of_capital=post_crisis_fraction_of_capital, solve_equilibrium=solve_equilibrium, muw_fun=muw_fun, Deltaw_fun=Deltaw_fun, 
               iota_bar_fun = iota_bar_fun, Phi_fun = Phi_fun, F_fun=F_fun, f_zeta_fun=f_zeta_fun, F_zeta_fun=F_zeta_fun,
               muK_fun=muK_fun, DeltaK_fun=DeltaK_fun, Delta_output_fun=Delta_output_fun, investment_cost_fun=investment_cost_fun, l_dN_fun=l_dN_fun, Simulate_model=Simulate_model, output_drop_and_gbar=output_drop_and_gbar, solve_other_moments=solve_other_moments,
               simulation_fun=simulation_fun, solve_welfare=solve_welfare, q_residual=q_residual)

# Then report the results
# key_moments: the first four calibration moments, the drop in liquidity
# bankruptcy attributable to government financing, private financing over
# average productivity, and TFP volatility.
return( list( key_moments = data.frame(equi_sol$key_moments[1:4], 
                                       survival_increase=(equi_sol_no_govt$liquidity_bankruptcy_total-equi_sol$liquidity_bankruptcy_total), 
                                       private_financing_over_GDP=equi_sol$private_financing/equi_sol$avg_productivity,
                                       TFP_vol_sd_over_mean=equi_sol$key_moments$TFP_vol_sd_over_mean) , 
              other_moments=equi_sol$other_moments, equi_sol=equi_sol, f_list=f_list ) )
}


# -------- Solve the optimal static welfare ----------
# For each w in w_grid, find the constant government financing cap g_bar that
# maximizes welfare. This is done for the full model and with the "slippery
# slope" channel shut off (g_bar no longer shifts the post-crisis w).
#
# Args:
#   calibration_results: output of calibration_moments(); its equi_sol$par and
#     f_list$solve_welfare are used.
#   w_grid: states at which welfare is evaluated.
#   g_vec: candidate g_bar values. The default (0 to 1.2 * d_bar) is evaluated
#     lazily and reads the `par` extracted from calibration_results below.
#   plot_process: if TRUE, plot each smoothed welfare curve (uses MyPlot).
# Returns: list(w_grid, g_bar_optimal, W_optimal, Welfare_matrix,
#   equi_sol_list, g_vec, g_optimal_shutoff_slippery_slope_vec,
#   W_optimal_shutoff_slippery_slope).
Optimal_static_welfare <- function( calibration_results, w_grid=seq(0,1,length.out=50), g_vec = seq( 0, par$d_bar*1.2, length.out = 200 ), plot_process=FALSE ){
  equi_sol <- calibration_results$equi_sol
  f_list <- calibration_results$f_list
  # Must be bound before g_vec is first touched: g_vec's default refers to it.
  par <- equi_sol$par
  n_w <- length(w_grid)
  n_g <- length(g_vec)
  # Rows = w states, columns = candidate g_bar values.
  Welfare_matrix <- as.matrix(w_grid %*% t(g_vec))
  Welfare_matrix_shutoff_slippery_slope <- Welfare_matrix
  equi_sol_list <- list()
  print( "******** Solving the optimal welfare ********" )
  for (j in 1:n_g) {
    print(paste0("Progress:", round(j / n_g * 100, 1), "%"))
    par$g_bar <- g_vec[j]
    full_model <- f_list$solve_welfare(d_bar = par$d_bar, g_bar = par$g_bar)
    Welfare_matrix[, j] <- full_model$W_fun(w_grid)
    equi_sol_list[[j]] <- full_model
    # Same policy, but with the slippery-slope channel shut off.
    no_slope <- f_list$solve_welfare(d_bar = par$d_bar, g_bar = par$g_bar, shutoff_slippery_slope = TRUE)
    Welfare_matrix_shutoff_slippery_slope[, j] <- no_slope$W_fun(w_grid)
  }

  # Maximizer of a loess fit over [min(g_vec), max(g_vec)].
  argmax_smoothed <- function(fit) {
    optimize(function(g_bar) -predict(fit, g_bar), interval = c(min(g_vec), max(g_vec)))$minimum
  }
  g_optimal_vec <- rep(0, n_w)
  g_optimal_shutoff_slippery_slope_vec <- rep(0, n_w)
  optimal_W_vec <- rep(0, n_w)
  # "Perceived" optimum when the slippery slope is ignored, valued below with
  # the true welfare curve.
  optimal_W_vec_shutoff_slippery_slope <- rep(0, n_w)

  for (i in 1:n_w) {
    fit_full <- loess(Welfare_matrix[i, ] ~ g_vec)
    g_star <- argmax_smoothed(fit_full)
    g_optimal_vec[i] <- g_star
    optimal_W_vec[i] <- predict(fit_full, g_star)

    fit_no_slope <- loess(Welfare_matrix_shutoff_slippery_slope[i, ] ~ g_vec)
    g_star_no_slope <- argmax_smoothed(fit_no_slope)
    g_optimal_shutoff_slippery_slope_vec[i] <- g_star_no_slope
    # Use the actual welfare to calculate the true value.
    optimal_W_vec_shutoff_slippery_slope[i] <- predict(fit_full, g_star_no_slope)

    if (plot_process) {
      w <- w_grid[i]
      MyPlot(g_vec, predict(fit_full, g_vec), main = paste0("welfare at w=", round(w, 2)), xlab = "g_bar", ylab = "W")
      abline(v = g_star, col = "red", lwd = 2)
      text(g_star + 0.04, quantile(Welfare_matrix[i, ], 0.3), "g_bar*", col = "red")
      Sys.sleep(0.1)
    }
  }
  list(w_grid = w_grid, g_bar_optimal = g_optimal_vec, W_optimal = optimal_W_vec,
       Welfare_matrix = Welfare_matrix, equi_sol_list = equi_sol_list, g_vec = g_vec,
       g_optimal_shutoff_slippery_slope_vec = g_optimal_shutoff_slippery_slope_vec,
       W_optimal_shutoff_slippery_slope = optimal_W_vec_shutoff_slippery_slope)
}



# ------------- Dynamic g_bar(w) -------------------------
# Important note: in this case, the capital price qH and qL will be w-dependent. 
# Solve for equilibrium of dynamic gamma, GIVEN the dynamic gamma function gamma(w)
solve_for_equilibrium_with_dynamic_g <- function(calibration_results, g_fun = approxfun(c(0,1),c(0.15,0.25)) , w_grid=seq(0.0001,0.9999,length.out = 20), step_multiplier=1.1, accuracy=0.002, iter_max = 20, step_size_default=0.5){
  # Solve the dynamic equilibrium capital prices qH(w) and qL(w) when the
  # government intervention level depends on the state: g_bar = g_fun(w).
  #
  # Arguments
  #   calibration_results : list with `equi_sol` (its `par`, `qH`, `qL` are read
  #                         here) and `f_list` (model component functions).
  #   g_fun               : government intervention policy, a function of w.
  #   w_grid              : grid of w on which qH(w) and qL(w) are solved.
  #   step_multiplier     : factor applied to the false-time step size; the step
  #                         grows when the residual fell in the previous round
  #                         and shrinks otherwise.
  #   accuracy            : stop once sum(|dqH/dt| + |dqL/dt|) is below this.
  #   iter_max            : maximum number of false-time iterations.
  #   step_size_default   : initial false-time step size.
  # Returns a list with
  #   TB            : data.frame indexed by w with prices, post-crisis states,
  #                   investment, crisis outcomes, drifts and government credit;
  #   diff_vec      : residual after each false-time iteration;
  #   step_size_vec : step size used in each iteration;
  #   g_fun, par, f_list : passed through for downstream use.
  equi_sol = calibration_results$equi_sol
  f_list = calibration_results$f_list
  par = equi_sol$par

  # Step 1: initialize q(w) by solving the pricing conditions for (qH, qL) at
  # each point of a coarse grid separately, taking g_bar = g_fun(w) as given.
  print( "*** Solving dynamic q(w): Initialization ***" )
  w_grid_init = seq(0,1,length.out=25)   # This initialization grid can be coarser than the main grid.
  init_vec = rep(0, length(w_grid_init))
  TB_init = data.frame( w=w_grid_init, g_bar=g_fun(w_grid_init), qH=init_vec, qL=init_vec)
  capital_pricing_residual = f_list$q_residual
  for( iter in seq_len(nrow(TB_init)) ){
    w = TB_init$w[iter]
    g_bar = TB_init$g_bar[iter]
    q_sol = fsolve( function(q) capital_pricing_residual(q, d_bar=par$d_bar, g_bar),  c(equi_sol$qH, equi_sol$qL) )
    qH = q_sol$x[1]
    qL = q_sol$x[2]
    TB_init[iter,] = data.frame( w, g_bar, qH, qL )
  }

  # Step 2: use the false time derivative method to account for drifts and
  # jumps of q(w) and solve the functions globally on w_grid.
  qH_vec = approx( TB_init$w, TB_init$qH, w_grid)$y
  qL_vec = approx( TB_init$w, TB_init$qL, w_grid)$y
  w_post_vec = w_grid
  diff = 1;  diff_vec = c()
  iter=0
  step_size_vec=c()
  step_size = step_size_default
  print( "*** Solving dynamic q(w): Iteration ***" )
  while( diff > accuracy && iter<iter_max ){
    print( paste0("iteraion ", iter, ": diff= ", round(diff,3), ", step size=", round(step_size,3)) )
    # After a few rounds, stop as soon as the residual increased in the last round.
    if(iter>3){
      if( diff_vec[length(diff_vec)] > diff_vec[length(diff_vec)-1] ){
        break;
      }
    }
    # Normal-time investment and the drifts of w, qH(w) and qL(w).
    i_bar_H_vec = f_list$iota_bar_fun(qH_vec)
    i_bar_L_vec = f_list$iota_bar_fun(qL_vec)
    muw_vec = f_list$muw_fun(i_bar_H_vec, i_bar_L_vec, w_grid)  # Drift of w
    qH_prime_vec = c(derivative_fun(w_grid, qH_vec))
    qL_prime_vec = c(derivative_fun(w_grid, qL_vec))
    mu_qH_vec = qH_prime_vec * muw_vec  # Drift of qH(w)
    mu_qL_vec = qL_prime_vec * muw_vec  # Drift of qL(w)

    # Crisis profits depend on post-crisis capital prices, so the post-crisis
    # state w_post is found by a fixed-point search at every grid point.
    crisis_profit_H_vec = rep(0,length(w_grid))
    crisis_profit_L_vec = rep(0,length(w_grid))
    qH_post_vec = rep(0,length(w_grid))
    qL_post_vec = rep(0,length(w_grid))
    for(iter_w in seq_along(w_grid)){
      print( paste0("Progress:", round(iter_w/length(w_grid)*100,1), "%" ) )
      w = w_grid[iter_w]
      qH = qH_vec[iter_w];  qL = qL_vec[iter_w]
      # Residual of the fixed point: the crisis decisions taken at q(w_post)
      # imply a jump in w that must land exactly on w_post.
      # rule=2 keeps q defined if w_post leaves the span of w_grid.
      residual_post_crisis_q <- function( w_post ){
        qH_post = approx(w_grid, qH_vec, w_post, rule=2)$y
        qL_post = approx(w_grid, qL_vec, w_post, rule=2)$y
        result_H = f_list$post_crisis_fraction_of_capital(qH_post, par$d_bar, g_fun(w), type="H")
        result_L = f_list$post_crisis_fraction_of_capital(qL_post, par$d_bar, g_fun(w), type="L")
        w_post_new = w + f_list$Deltaw_fun(result_H$kappa, result_L$kappa, w)
        return( w_post - w_post_new )
      }
      search_interval = c( max(w-0.1,0.0001), min(w+0.2,0.9999) )
      if( residual_post_crisis_q(search_interval[1]) * residual_post_crisis_q(search_interval[2]) > 0 ){
        # No sign change on the bracket (this happens at the boundary): keep w.
        w_post = w
      }else{
        w_post = uniroot( residual_post_crisis_q, search_interval )$root
      }
      # Post-shock prices and crisis outcomes at the solved w_post.
      qH_post = approx(w_grid, qH_vec, w_post, rule=2)$y
      qL_post = approx(w_grid, qL_vec, w_post, rule=2)$y
      result_H = f_list$post_crisis_fraction_of_capital(qH_post, par$d_bar, g_fun(w), type="H")
      result_L = f_list$post_crisis_fraction_of_capital(qL_post, par$d_bar, g_fun(w), type="L")

      w_post_vec[iter_w] = w_post
      qH_post_vec[iter_w] = qH_post
      qL_post_vec[iter_w] = qL_post
      crisis_profit_H_vec[iter_w] = integrate( function(zeta) f_list$f_zeta_fun(zeta,type="H") * (result_H$obj_optimal_fun(zeta) - qH), lower = 0, upper=zeta_upper  )$value
      crisis_profit_L_vec[iter_w] = integrate( function(zeta) f_list$f_zeta_fun(zeta,type="L") * (result_L$obj_optimal_fun(zeta) - qL), lower = 0, upper=zeta_upper  )$value
    }

    # False time derivative: deviation of the required return r from the
    # total expected return on capital of each type.
    Dt_qH = par$r - (  mu_qH_vec + (par$AH-f_list$Phi_fun(i_bar_H_vec))/qH_vec + (i_bar_H_vec - par$delta) + par$lambda*crisis_profit_H_vec/qH_vec    )
    Dt_qL = par$r - (  mu_qL_vec + (par$AL-f_list$Phi_fun(i_bar_L_vec))/qL_vec + (i_bar_L_vec - par$delta) + par$lambda*crisis_profit_L_vec/qL_vec   )

    # Grow the step while the residual keeps declining, shrink it otherwise.
    if( iter>=2 && diff_vec[length(diff_vec)]<diff_vec[length(diff_vec)-1] ){
      step_size = step_size*step_multiplier
    }else{
      step_size = step_size/step_multiplier
    }
    step_size_vec = c(step_size_vec, step_size)
    qH_vec_new = qH_vec - step_size*Dt_qH
    qL_vec_new = qL_vec - step_size*Dt_qL

    diff = sum( abs(Dt_qH) + abs(Dt_qL) )
    qH_vec = qH_vec_new
    qL_vec = qL_vec_new
    diff_vec = c(diff_vec, diff)
    iter = iter+1
    if(iter>3){
      MyPlot( 1:iter, diff_vec, xlab="iteration", ylab="diff from previous round" )
    }
    # **** Important: supsmu is better than conspline.  Use smoothing in each iteration is better than the "delayed smoothing".
    qH_vec = supsmu(w_grid,qH_vec)$y
    qL_vec = supsmu(w_grid,qL_vec)$y
  }

  if( diff > accuracy ){
    print( "qH_vec and qL_vec not accurately solved for gamma(w)!" )
  }

  # Step 3: given the solved prices, compute crisis outcomes, drifts,
  # productivity and government credit at every w.
  init_vec = rep(0, length(w_grid))
  TB = data.frame( w=w_grid, g_bar = g_fun(w_grid), w_post=w_post_vec, qH=qH_vec, qL=qL_vec, qH_post=qH_post_vec, qL_post=qL_post_vec,
                   crisis_profit_H=crisis_profit_H_vec, crisis_profit_L=crisis_profit_L_vec,  iH=f_list$iota_bar_fun(qH_vec), iL=f_list$iota_bar_fun(qL_vec),
                   kappa_H=init_vec,  kappa_L=init_vec, Delta_K=init_vec, Delta_w=init_vec,
                   muK=init_vec, muw=init_vec,  l_dN = init_vec,
                   productivity=init_vec, govt_credit_H=init_vec, govt_credit_L=init_vec, govt_credit=init_vec )
  for( iter in seq_along(w_grid) ){
    w = w_grid[iter]
    # Investment at THIS grid point. (Previously the row was rebuilt from the
    # stale scalars qH/qL left over from the iteration loop, which wrote the
    # same iH/iL into every row.)
    iH = TB$iH[iter]
    iL = TB$iL[iter]
    result_H = f_list$post_crisis_fraction_of_capital(TB$qH_post[iter], par$d_bar, g_fun(w), type="H")
    result_L = f_list$post_crisis_fraction_of_capital(TB$qL_post[iter], par$d_bar, g_fun(w), type="L")
    kappa_H = result_H$kappa
    kappa_L = result_L$kappa
    Delta_K = f_list$DeltaK_fun(kappa_H, kappa_L, w)
    Delta_w = f_list$Deltaw_fun(kappa_H, kappa_L, w)
    # Important: consider entry in the drifts.
    muK = f_list$muK_fun( iH, iL, w )
    muw = f_list$muw_fun( iH, iL, w )
    productivity = w*par$AH + (1-w)*par$AL
    govt_credit_H = result_H$govt_sector_financing
    govt_credit_L = result_L$govt_sector_financing
    govt_credit = w*govt_credit_H + (1-w)*govt_credit_L
    l_dN =  w*f_list$l_dN_fun(result_H$l_fun) + (1-w)*f_list$l_dN_fun(result_L$l_fun)
    TB[iter,] = data.frame(  w, g_bar=g_fun(w), w_post=w_post_vec[iter], qH=qH_vec[iter], qL=qL_vec[iter], qH_post=qH_post_vec[iter], qL_post=qL_post_vec[iter],
                             crisis_profit_H=crisis_profit_H_vec[iter], crisis_profit_L=crisis_profit_L_vec[iter],  iH=iH, iL=iL,
                             kappa_H=kappa_H,  kappa_L=kappa_L, Delta_K=Delta_K, Delta_w=Delta_w,
                             muK=muK, muw=muw,  l_dN=l_dN,
                             productivity=productivity, govt_credit_H=govt_credit_H, govt_credit_L=govt_credit_L, govt_credit=govt_credit   )
  }
  return( list(TB=TB, diff_vec=diff_vec, step_size_vec=step_size_vec, g_fun=g_fun, par=par, f_list=f_list) )
}


# Given a dynamic policy g_bar(w), solve the welfare function W(w) of the economy.
solve_welfare_dynamic_gbar <- function(result_solving_dynamic_equil, calibration_results,  accuracy=0.002, step_multiplier=1.02, w_grid=seq(0.001,0.999,length.out=100)){
  # Solves W(w) on `w_grid` by false-time iteration on
  #   0 = A - I + W*muK + W'*muw
  #       + lambda*(W(w + Delta_w)*(1 + Delta_K) - W - l_dN) - r*W.
  #
  # Arguments
  #   result_solving_dynamic_equil : output of solve_for_equilibrium_with_dynamic_g
  #                                  (its `TB` table and `g_fun` are used).
  #   calibration_results          : list with `equi_sol$par` and `f_list`.
  #   accuracy                     : currently NOT used; the iteration stops when
  #                                  sum(|dW/dt|) < 1e-5 (kept for compatibility).
  #   step_multiplier              : factor by which the step size grows or
  #                                  shrinks between rounds (previously ignored
  #                                  and hard-coded at its default, 1.02).
  #   w_grid                       : grid of w on which W(w) is solved.
  # Returns a list with spline functions W_fun and W_prime_fun, the grid, the
  # interpolated equilibrium inputs, and the convergence history.
  N = length(w_grid)
  TB = result_solving_dynamic_equil$TB
  g_fun = result_solving_dynamic_equil$g_fun
  par = calibration_results$equi_sol$par
  f_list = calibration_results$f_list

  # Initial guess: a straight line between the static welfare at the two ends
  # of the grid, each computed with the policy level at that end.
  result_w0 = f_list$solve_welfare( w_grid, d_bar=par$d_bar, g_bar=g_fun(0) )
  result_w1 = f_list$solve_welfare( w_grid, d_bar=par$d_bar, g_bar=g_fun(1) )
  W0 = result_w0$W_fun(w_grid[1]);  W1 = result_w1$W_fun(w_grid[N])
  W_vec = seq(W0, W1, length.out = N)

  # Move equilibrium objects onto w_grid. rule=2 extends end values flatly if
  # w_grid reaches beyond the equilibrium grid, instead of producing NA.
  on_grid <- function(column) approx( TB$w, TB[[column]], w_grid, rule=2 )$y
  A = w_grid*par$AH + (1-w_grid)*par$AL
  iH = on_grid("iH")
  iL = on_grid("iL")
  I = f_list$investment_cost_fun(iH, iL, w_grid)   # normal time investment
  muw_vec = on_grid("muw")
  muK_vec = on_grid("muK")
  Delta_w_vec = on_grid("Delta_w")
  Delta_K_vec = on_grid("Delta_K")
  l_dN_vec = on_grid("l_dN")

  tolerance = 0.00001
  iter_max = 1000
  residual = 1   # named `residual` so it does not shadow base::diff, used below
  diff_vec = c()
  iter = 0
  step_size = 0.1;  step_size_vec = c()
  while( residual > tolerance && iter < iter_max ){
    iter = iter+1
    W_vec = conspline(W_vec, w_grid, 1)$muhat   # smooth W before differentiating
    W_prime_vec = c(derivative_fun(w_grid, W_vec))
    W_post_dN = approx(w_grid, W_vec, w_grid + Delta_w_vec, rule=2 )$y   # W right after a crisis jump
    Dt_W = A-I + W_vec*muK_vec + W_prime_vec*muw_vec + par$lambda*(W_post_dN*(1+Delta_K_vec)-W_vec-l_dN_vec) - par$r*W_vec

    # From round 10 on, grow the step if the residual fell at every step across
    # the last (up to) 50 recorded residuals; otherwise shrink it (floor 1e-4).
    if( iter>=10 && all( diff(tail(diff_vec, 50)) < 0 ) ){
      step_size = step_size*step_multiplier
    }else{
      step_size = max(step_size/step_multiplier, 0.0001)
    }
    step_size_vec = c(step_size_vec, step_size)

    W_vec = W_vec + step_size*Dt_W
    residual = sum( abs(Dt_W) )
    diff_vec = c(diff_vec, residual)
  }
  if( residual>0.01 ){
    print( "Welfare function is not accurately solved!" )
  }
  W_fun <- splinefun(w_grid, W_vec)
  W_prime_fun <- splinefun(w_grid, derivative_fun(w_grid, W_vec))
  return( list(W_fun=W_fun, W_prime_fun=W_prime_fun, w_grid=w_grid, iH=iH, iL=iL, l_dN=l_dN_vec, I=I,A=A, muw=muw_vec, muK=muK_vec, Delta_w=Delta_w_vec, Delta_K=Delta_K_vec, diff_vec=diff_vec, step_size_vec=step_size_vec) )
}


# ------- Optimal Welfare under Dynamic g_bar(w) -----
solve_optimal_welfare_dynamic_gbar <- function(calibration_results, result_optimal_static_W, w_grid = seq(0,1,length.out = 50), step_multiplier=1.1, accuracy=0.003, accuracy_equilibrium_solution=0.0001, iter_max = 12, inner_max=200){
  # Iteratively solve for the welfare-maximizing dynamic policy g_bar(w).
  # Each round:
  #   1. solves the equilibrium and welfare under the current policy;
  #   2. at every w, picks g_bar in [0, AH] maximizing the crisis-time welfare
  #      change W(w + Delta_w)*(1 + Delta_K) - l_dN - W(w);
  #   3. smooths the new policy with loess.
  # It stops when the summed absolute policy change is below `accuracy` or
  # after `iter_max` rounds.
  #
  # Arguments
  #   calibration_results     : output of the calibration ("calibration_moments").
  #   result_optimal_static_W : optimal static welfare result; its `w_grid` and
  #                             `g_bar_optimal` initialize the policy.
  #   w_grid                  : grid of w on which the policy is optimized.
  #   step_multiplier, accuracy_equilibrium_solution, inner_max :
  #                             currently NOT used (kept for compatibility).
  # Returns a list with the grid, the optimal policy vector, the final dynamic
  # equilibrium and welfare results, and the convergence history.
  par = calibration_results$equi_sol$par
  f_list = calibration_results$f_list
  # Initialize with the static optimal policy. rule=2 extends its end values if
  # w_grid reaches beyond the static grid, instead of producing NA.
  g_vec_optimal = approx( result_optimal_static_W$w_grid, result_optimal_static_W$g_bar_optimal, w_grid, rule=2 )$y

  diff = 1;  diff_vec = c()
  iter=0
  while( diff > accuracy && iter<iter_max ){
    print( paste0( "------- Solving for optimal welfare with dynamic g_bar: iteration ", iter, " -------" ) )
    result = solve_for_equilibrium_with_dynamic_g(calibration_results, g_fun = approxfun(w_grid,g_vec_optimal))
    result_welfare = solve_welfare_dynamic_gbar(result, calibration_results)
    W_fun = result_welfare$W_fun   # welfare function
    TB = result$TB
    g_vec_prev = g_vec_optimal  # Store the previous round.

    # Crisis-time welfare change at state w when intervention is set to g_bar,
    # holding normal-time prices qH(w), qL(w) at the current equilibrium.
    obj_fun <- function(g_bar, w, num_accuracy=10^(-8)){
      qH = approx( TB$w, TB$qH, w, rule=2 )$y
      qL = approx( TB$w, TB$qL, w, rule=2 )$y
      resultH = f_list$solve_individual_optimal_faster(qH, d_bar=par$d_bar, g_bar=g_bar)
      resultL = f_list$solve_individual_optimal_faster(qL, d_bar=par$d_bar, g_bar=g_bar)
      l_dN = integrate( function(zeta) w*resultH$l_fun(zeta)*f_list$f_zeta_fun(zeta,type="H") + (1-w)*resultL$l_fun(zeta)*f_list$f_zeta_fun(zeta,type="L") , lower=0, upper=zeta_upper, abs.tol=num_accuracy )$value
      kappaH = integrate( function(zeta) f_list$F_fun(zeta+resultH$l_fun(zeta))*f_list$f_zeta_fun(zeta,type="H"), lower=0, upper=zeta_upper , abs.tol=num_accuracy)$value
      kappaL = integrate( function(zeta) f_list$F_fun(zeta+resultL$l_fun(zeta))*f_list$f_zeta_fun(zeta,type="L"), lower=0, upper=zeta_upper , abs.tol=num_accuracy)$value
      Delta_K = f_list$DeltaK_fun(kappaH,kappaL,w)
      Delta_w = f_list$Deltaw_fun(kappaH,kappaL,w)
      return( W_fun( w+Delta_w )*(1+Delta_K) - l_dN - W_fun(w) )
    }
    # Vectorized wrapper over g_bar (one obj_fun call per element).
    obj_vec_fun <- function(g_bar_vec, w, num_accuracy=10^(-8)){
      vapply(g_bar_vec, function(g_bar) obj_fun(g_bar, w, num_accuracy), numeric(1))
    }

    for(i in seq_along(w_grid)){
      print( paste0("Progress: ", round(i/length(w_grid)*100,1), "%" ) )
      solution = optimize( function(g_bar) -obj_vec_fun(g_bar, w_grid[i], num_accuracy=10^(-12)), interval= c(0,par$AH) )
      g_vec_optimal[i] = solution$minimum
    }
    g_vec_optimal[length(w_grid)] = g_vec_optimal[length(w_grid)-1]  # At w=1, the solution will go wild because the objective function is flat for large g_bar
    # Then do some smoothing on the solution. Numerical accuracy is challenging.
    g_vec_optimal_smoothed = loess( g_vec_optimal~w_grid )$fitted
    g_vec_optimal_smoothed[1] = result_optimal_static_W$g_bar_optimal[1]  # Note: at w=0, the two policies always coincide. Imposing the condition increases numerical accuracy.
    g_vec_optimal = g_vec_optimal_smoothed
    diff = sum( abs( g_vec_optimal - g_vec_prev ) )
    print( paste0( "CURRENT iteration error: ", diff ) )
    diff_vec = c(diff_vec, diff)
    iter = iter+1
  }
  if(diff > accuracy){
    print( "Optimal welfare NOT accurately solved!" )
  }
  # Then solve for the solution and welfare.
  result = solve_for_equilibrium_with_dynamic_g(calibration_results, g_fun = approxfun(w_grid,g_vec_optimal))
  result_welfare = solve_welfare_dynamic_gbar(result, calibration_results)

  return( list(w_grid=w_grid, g_vec_optimal=g_vec_optimal, result_dynamic_equi=result, result_welfare=result_welfare, diff_vec=diff_vec) )
}


simulation_dynamic_gbar_fun <- function(equi_dynamic, years_sim=50, seed=1,  dt=1, dN_vec=NA, w_init=w0){
  # Simulate one path of the state w and the capital stock K under a dynamic
  # policy, then attach per-period macro quantities.
  #
  # Arguments
  #   equi_dynamic : output of solve_for_equilibrium_with_dynamic_g (its `TB`,
  #                  `par`, `f_list` and `g_fun` are used).
  #   years_sim    : simulated horizon; the path has round(years_sim/dt) periods.
  #   seed         : RNG seed used when crisis indicators are drawn here.
  #   dt           : period length.
  #   dN_vec       : optional 0/1 crisis indicators per period; if all NA, they
  #                  are drawn as Bernoulli(lambda*dt).
  #   w_init       : initial w. NOTE(review): the default `w0` is resolved
  #                  outside this function (global scope); pass w_init
  #                  explicitly if no global `w0` is defined.
  # Returns a data.table with one row per period.
  set.seed(seed)
  par = equi_dynamic$par
  f_list = equi_dynamic$f_list
  Delta_output_fun = f_list$Delta_output_fun
  TB = equi_dynamic$TB
  # Interpolants of equilibrium objects in w. rule=2 uses the end values when
  # the simulated w drifts outside the solution grid; without it the first
  # such step returns NA, which then propagates through the rest of the path.
  interp_in_w <- function(y) approxfun(TB$w, y, rule=2)
  iH_fun = interp_in_w(TB$iH)
  iL_fun = interp_in_w(TB$iL)
  kappaH_fun = interp_in_w(TB$kappa_H)
  kappaL_fun = interp_in_w(TB$kappa_L)
  qH_fun = interp_in_w(TB$qH)
  qL_fun = interp_in_w(TB$qL)
  N_sim = round(years_sim/dt)
  if( all(is.na(dN_vec)) ){
    dN_vec = as.numeric(runif(N_sim)<par$lambda*dt)
  }
  w_vec = rep(w_init, N_sim)
  K_vec = rep(K0, N_sim)
  # seq_len(N_sim)[-1] is empty when N_sim <= 1 (2:N_sim would run backwards).
  for(iter in seq_len(N_sim)[-1]){
    w = w_vec[iter-1]
    K = K_vec[iter-1]
    iH = iH_fun(w);  iL = iL_fun(w)
    kappaH = kappaH_fun(w);  kappaL = kappaL_fun(w)
    muw = f_list$muw_fun(iH, iL, w)
    muK = f_list$muK_fun(iH, iL, w)
    Deltaw = f_list$Deltaw_fun(kappaH, kappaL, w)
    DeltaK = f_list$DeltaK_fun(kappaH, kappaL, w)
    w_vec[iter] = w + muw*dt + Deltaw*dN_vec[iter-1] # The timing here is important for the calculations of welfare below.
    K_vec[iter] = K * ( 1 + muK*dt +  DeltaK*dN_vec[iter-1] )
  }
  TB_sim = as.data.table(data.frame( T=seq_len(N_sim)*dt,  w=w_vec, K=K_vec, dN=dN_vec))

  TB_sim[  , output_per_K:= par$AH*w+par$AL*(1-w) ]  # Per unit of capital productivity
  TB_sim[  , GDP:= output_per_K*K ]  # Total output: productivity per unit of capital times capital
  TB_sim[  , qH := qH_fun(w) ]
  TB_sim[  , qL := qL_fun(w) ]
  TB_sim[  , q := qH*w + qL*(1-w) ]
  TB_sim[  , iH := iH_fun(w) ]
  TB_sim[  , iL := iL_fun(w) ]
  TB_sim[  , normal_investment_per_K:= w*iH+(1-w)*iL ]  # Note: translate lumpy investment into flow-based ones.
  TB_sim[  , normal_investment_to_output:= normal_investment_per_K/output_per_K ]
  TB_sim[  , normal_consumption_per_K := output_per_K-normal_investment_per_K ]
  TB_sim[  , normal_consumption_to_output := normal_consumption_per_K/output_per_K ]
  TB_sim[  , KH:= w*K ]
  TB_sim[  , KL:= (1-w)*K ]
  TB_sim[  , GDP_drop_in_crisis:= Delta_output_fun( kappaH_fun(w), kappaL_fun(w), w ) ]
  TB_sim[ , g_bar :=  equi_dynamic$g_fun(w) ]
  TB_sim[ , govt_credit :=  approx( TB$w, TB$govt_credit, w, rule=2 )$y ]
  return(TB_sim)
}


